{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Positive vs. Negative Sentiment Classification\n",
"\n",
"This notebook uses SHAP to interpret a sentiment classification model. The goal is to see how individual words in a movie review push the model's prediction toward positive or negative sentiment."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
],
"source": [
"import datasets\n",
"import numpy as np\n",
"import transformers\n",
"\n",
"import shap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the IMDB movie review dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load the test split of the IMDB dataset (movie reviews labeled positive or negative)\n",
"dataset = datasets.load_dataset(\"imdb\", split=\"test\")\n",
"\n",
"# Keep the first 20 reviews, truncated to 500 characters each, so they fit into the pipeline model\n",
"short_data = [v[:500] for v in dataset[\"text\"][:20]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load and run a sentiment analysis pipeline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).\n",
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
"Device set to use mps:0\n",
"/Users/aribaa/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/text_classification.py:111: UserWarning: `return_all_scores` is now deprecated, if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"[[{'label': 'NEGATIVE', 'score': 0.07581914216279984},\n",
" {'label': 'POSITIVE', 'score': 0.924180805683136}],\n",
" [{'label': 'NEGATIVE', 'score': 0.01834261603653431},\n",
" {'label': 'POSITIVE', 'score': 0.9816573858261108}]]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
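"# Load a pretrained sentiment analysis pipeline from Hugging Face (downloaded on first run).\n",
"# No model name is given, so the pipeline falls back to its default checkpoint.\n",
"# return_all_scores=True returns a score for both labels; newer transformers versions\n",
"# deprecate it in favor of top_k=None.\n",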
"classifier = transformers.pipeline(\"sentiment-analysis\", return_all_scores=True)\n",
"classifier(short_data[:2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain the sentiment analysis pipeline"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Create SHAP explainer to compute word-level importance\n",
"explainer = shap.Explainer(classifier)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"PartitionExplainer explainer: 3it [00:11, 11.25s/it] \n"
]
}
],
"source": [
"# Explain the pipeline's predictions on the first two reviews\n",
"shap_values = explainer(short_data[:2])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"
\n",
"